#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Translates a source file using a translation model."""
# Disable the garbage collector: this speeds up beam search a little at the cost of higher memory consumption.
import gc
gc.disable()

import os
import sys
import time
import json
import inspect
import argparse
import importlib
import traceback
from multiprocessing import Process, Queue, cpu_count

from collections import OrderedDict

import numpy as np

from nmtpy.logger           import Logger
from nmtpy.metrics          import get_scorer
from nmtpy.nmtutils         import idx_to_sent
from nmtpy.sysutils         import *
from nmtpy.filters          import get_filter
from nmtpy.iterators.bitext import BiTextIterator
from nmtpy.defaults         import INT, FLOAT

import nmtpy.cleanup as cleanup

# Set up the module-wide logger (without timestamps) and keep a handle to it.
Logger.setup(timestamp=False)
log = Logger.get()

# Avoid thread explosion: several worker processes are spawned below, so
# restrict OpenMP and MKL to a single thread per process.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

# Force CPU decoding. This assignment overwrites any THEANO_FLAGS value
# already present in the environment.
# NOTE(review): this only takes effect if theano has not been imported yet by
# one of the nmtpy imports above -- confirm.
os.environ["THEANO_FLAGS"] = "device=cpu,optimizer_including=local_remove_all_assert"

"""Worker process which does beam search."""
def translate_model(rqueue, wqueue, pid, models, beam_size, nbest, suppress_unks, get_att_alphas=False, seed=1234, mode="beamsearch"):
    """Worker process which does beam search.

    Repeatedly reads ``(sample_idx, data_dict)`` requests from ``rqueue``,
    decodes them and puts ``(sample_idx, hyps, scores, alignments)`` on
    ``wqueue``. The loop only ends when an exception is raised (or the
    process is terminated by its parent); in the exception case the traceback
    is printed and ``None`` is put on ``wqueue`` to signal the failure.

    Arguments:
        rqueue: queue-like object with ``get()`` providing the requests.
        wqueue: queue-like object with ``put()`` receiving the responses.
        pid: index of this worker (currently unused).
        models: list of models; decoding methods are taken from ``models[0]``
            while ``f_init``/``f_next`` of every model are given to beam search.
        beam_size: beam size passed to ``beam_search``.
        nbest: number of best hypotheses to send back per sample.
        suppress_unks: forwarded to ``beam_search``.
        get_att_alphas: forwarded to ``beam_search``.
        seed: currently unused.
        mode: one of "beamsearch", "forced", "sample" or "argmax".
    """
    try:
        # Pick the decoding callable once instead of eval()'ing a call string
        # for every sample.
        if mode == "beamsearch":
            beam_search = models[0].beam_search
            f_inits = [m.f_init for m in models]
            f_nexts = [m.f_next for m in models]

            def decode(data_dict):
                return beam_search(list(data_dict.values()), f_inits, f_nexts,
                                   beam_size=beam_size,
                                   get_att_alphas=get_att_alphas,
                                   suppress_unks=suppress_unks)

        elif mode in ("forced", "sample"):
            gen_sample = models[0].gen_sample

            def decode(data_dict):
                return gen_sample(data_dict)

        elif mode == "argmax":
            gen_sample = models[0].gen_sample

            def decode(data_dict):
                return gen_sample(data_dict, argmax=True)

        else:
            # Reported to the parent through the except clause below
            raise ValueError("Unknown decoding mode: %r" % (mode,))

        while True:
            # Get a sample
            req = rqueue.get()

            # Unpack sample idx and data_dict
            sample_idx, data_dict = req[0], req[1]

            # Get the translations, their scores and alignments
            trans, score, align = decode(data_dict)

            # Normalize scores according to sequence lengths
            score = score / np.array([len(s) for s in trans])

            # Sort the scores (ascending) and take the idx(s) of the best(s)
            best_idxs = np.argsort(score)[:nbest]

            # Hypotheses usually have different lengths: select them by index
            # instead of building a (ragged) numpy array out of them, which
            # recent NumPy versions refuse to do.
            trans = [trans[i] for i in best_idxs]

            # Check for attention weights
            if align is not None:
                align = [align[i] for i in best_idxs]

            # Send response back
            wqueue.put((sample_idx, trans, score[best_idxs], align))
    except Exception:
        traceback.print_exc()
        # Signal error back
        wqueue.put(None)

class Translator(object):
    """Starts worker processes and waits for the results.

    Typical usage is ``Translator(args)`` followed by ``set_model_options()``
    and ``start()``; the results can then be written with ``write_hyps()``,
    scored with ``compute_metrics()`` and exported with ``dump_json()``.
    """
    def __init__(self, args):
        """Stores the decoding options found in the parsed arguments `args`."""
        self.beam_size = args.beam_size

        # Always lists provided by argparse (nargs:'+'), or None if not given
        self.src_files = args.src_files
        self.ref_files = args.ref_files

        # Whether attention weights should be collected for json export
        self.export = args.export

        # Fetch other arguments
        self.nbest          = args.nbest
        self.seed           = args.seed
        self.mode           = args.decoder
        # The auto-detected job count may arrive as a float: make sure that
        # it can be used for list sizing and range()
        self.n_jobs         = int(args.n_jobs)
        self.data_mode      = args.validmode

        self.models         = []
        self.model_files    = args.models
        self.model_options  = []
        self.n_models       = len(self.model_files)

        self.suppress_unks  = args.suppress_unks

        # Post-processing filters
        self.filters = []

        # Create worker process pool
        self.processes = [None] * self.n_jobs

    def set_model_options(self):
        """Loads the models and prepares the data iterator to decode from.

        Fills `self.models`, `self.model_options`, `self.filters`,
        `self.trg_idict`, `self.iterator` and `self.n_sentences`. When no
        source files were given, `self.src_files` and `self.ref_files` are
        taken from the first model's validation data.
        """
        for mfile in self.model_files:
            log.info('Initializing model %s' % os.path.basename(mfile))
            # NOTE(review): recent NumPy versions may need allow_pickle=True
            # to load the pickled 'opts' entry -- confirm for your setup.
            model_options = dict(np.load(mfile)['opts'].tolist())

            # Import the module
            self.__class = importlib.import_module("nmtpy.models.%s" % model_options['model_type']).Model

            # Create the model
            model = self.__class(seed=self.seed, logger=None, **model_options)
            model.load(mfile)
            model.set_dropout(False)
            model.build_sampler()

            self.models.append(model)
            self.model_options.append(model_options)

        # Check for post-processing filter. The filters of the first model
        # are used, consistently with the membership test.
        first_options = self.model_options[0]
        if "filter" in first_options:
            log.info("Hypotheses will be processed by the filters: '%s'" % first_options['filter'])
            self.filters = [get_filter(f) for f in first_options['filter'].split(',')]

        # Get inverted dictionary from the model itself
        self.trg_idict = self.models[0].trg_idict

        #######################################################
        # Forced decoding/NMT rescoring for given src/trg pairs
        #######################################################
        if self.mode == "forced" and self.src_files and self.ref_files:
            log.info("Using only %s as reference file for forced decoding." % self.ref_files[0])
            self.iterator = BiTextIterator(
                                        batch_size=1,
                                        srcfile=self.src_files[0], srcdict=self.models[0].src_dict,
                                        trgfile=self.ref_files[0], trgdict=self.models[0].trg_dict,
                                        n_words_src=self.models[0].n_words_src,
                                        n_words_trg=self.models[0].n_words_trg,
                                        trg_name='y_true')
            self.iterator.read()

            # start() and dump_json() need the number of samples in this mode
            # as well.
            # NOTE(review): assumes read() populates n_samples just like the
            # validation iterator used below -- confirm.
            self.n_sentences = self.iterator.n_samples

        #########################
        # Normal translation mode
        #########################
        else:
            if self.src_files is not None:
                # Pass the files to the model
                # NOTE: Not quite model agnostic way of doing things.
                self.models[0].data['valid_src'] = self.src_files[0]
                self.models[0].data['valid_trg'] = self.ref_files
                if 'valid_img' in self.models[0].data:
                    self.models[0].data['valid_img'] = self.src_files[1]

            # Initialize model's validation data iterator
            # NOTE: data_mode is for best-source-selection decoding for multimodal systems
            load_valid_data = self.models[0].load_valid_data
            # inspect.getargspec() does not exist anymore in recent Python
            # versions, only fall back to it if getfullargspec() is missing.
            get_argspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
            if 'data_mode' in get_argspec(load_valid_data).args:
                load_valid_data(from_translate=True, data_mode=self.data_mode)
            else:
                load_valid_data(from_translate=True)

            # Set self.iterator to self.models[0].valid_iterator
            self.iterator = self.models[0].valid_iterator
            self.n_sentences = self.iterator.n_samples

            # Assume validation data encoded in the model
            if self.src_files is None:
                log.info("No test data given, assuming validation dataset.")

                self.src_files = listify(self.models[0].data['valid_src'])

                # User may provide another reference set in 'valid_trg_orig' for example
                # with compound splitting reverted so that we can compute
                # the metrics correctly.
                # NOTE: May be avoided by using filters on reference sentences.
                if "valid_trg_orig" in self.models[0].data:
                    self.ref_files = listify(self.models[0].data['valid_trg_orig'])
                else:
                    self.ref_files = listify(self.models[0].valid_ref_files)

        log.info('I will translate %d samples' % self.n_sentences)

        # Print information
        log.info("Source file(s)")
        for f in self.src_files:
            log.info("  %s" % f)

        if self.ref_files:
            log.info("Reference file(s)")
            for f in self.ref_files:
                log.info("  %s" % f)

    def start(self):
        """Decodes all the samples of `self.iterator` with worker processes.

        Fills `self.trans` (list of hypothesis strings per sample),
        `self.scores` and `self.att_weights`, all indexed by sample id.
        Exits the program with status 1 if a worker signals a failure.
        """
        # Create input and output queues for processes
        write_queue = Queue()
        read_queue  = Queue()

        # Create processes
        for idx in range(self.n_jobs):
            self.processes[idx] = Process(target=translate_model,
                                          args=(write_queue, read_queue, idx, self.models, self.beam_size,
                                          self.nbest, self.suppress_unks, self.export,
                                          self.seed, self.mode))
            # Start process and register for cleanup
            self.processes[idx].start()
            cleanup.register_proc(self.processes[idx].pid)

        # Register the created processes to clean them after
        cleanup.register_handler(log)

        # Send data to worker processes
        for idx in range(self.n_sentences):
            write_queue.put((idx, next(self.iterator)))

        log.info("Distributed %d sentences to worker processes." % self.n_sentences)

        # Receive the results
        self.trans       = [None] * self.n_sentences
        self.scores      = [None] * self.n_sentences

        # Will be filled if --export is passed
        self.att_weights = [None] * self.n_sentences

        # Performance computation stuff
        start_time = per100_time = time.time()

        for i in range(self.n_sentences):
            # Get response from worker
            resp = read_queue.get()

            if resp is None:
                # Worker(s) failed
                log.info('One or more of the workers failed, exiting.')
                sys.exit(1)

            # Responses arrive in completion order: the sample id tells
            # where the results belong to
            sample_idx = resp[0]

            # Get the hypotheses, scores and attention weights if any
            hyps, self.scores[sample_idx], attw = resp[1:]

            # Did we receive attention weights from beam search?
            # Only the ones of the best hypothesis are kept.
            if attw is not None:
                self.att_weights[sample_idx] = attw[0]

            # Place the hypotheses into their relevant places
            self.trans[sample_idx] = [idx_to_sent(self.trg_idict, hyp) for hyp in hyps]

            # Print progress
            if (i+1) % 100 == 0:
                per100_time = time.time() - per100_time
                log.info("%4d/%d sentences completed (%.2f seconds)" % ((i+1), self.n_sentences, per100_time))
                per100_time = time.time()

        # Total time spent during beam search
        total_time      = time.time() - start_time
        # Guard against a zero elapsed time (e.g. nothing to translate)
        sent_per_sec    = int(self.n_sentences / total_time) if total_time > 0 else 0

        log.info("-------------------------------------------")
        log.info("Total decoding time: %3.3f seconds (%d sentences / sec)" % (total_time, sent_per_sec))

        # Compute word-based time statistics as well
        if self.nbest == 1:
            n_words         = float(sum([len(s[0].split(' ')) for s in self.trans]))
            word_per_sec    = int(n_words / total_time) if total_time > 0 else 0
            log.info("~%d words / sec" % word_per_sec)

        # Stop workers
        for pidx in range(self.n_jobs):
            self.processes[pidx].terminate()

    def write_hyps(self, filename, dump_scores=False):
        """Applies the post-processing filters and writes the hypotheses.

        The filtered hypotheses replace the ones in `self.trans`. The file
        contains "idx ||| hypothesis ||| score" lines if scores are requested
        (forced decoding or `dump_scores`) or if `self.nbest` > 1, and one
        plain hypothesis per line otherwise.
        """
        # Apply post-processing filters like compound stitching
        for hyps in self.trans:
            # List of hyps (length 1 if nbest==1) per source sentence
            for j, hyp in enumerate(hyps):
                for filt in self.filters:
                    hyp = filt(hyp)
                hyps[j] = hyp

        # Write file
        with open(filename, 'w') as f:
            if self.mode == "forced" or dump_scores:
                # Write the first (best) hypothesis and its score for each
                # sentence. The scores of a sentence may hold more than one
                # element when nbest > 1, so pick the first one explicitly.
                for idx, (tr, sc) in enumerate(zip(self.trans, self.scores)):
                    f.write("%d ||| %s ||| %.6f\n" % (idx, tr[0], np.ravel(sc)[0]))

            elif self.nbest > 1:
                # We have n hypotheses and n scores for each sentence
                for idx, (trs, scs) in enumerate(zip(self.trans, self.scores)):
                    for tr, sc in zip(trs, scs):
                        f.write("%d ||| %s ||| %.6f\n" % (idx, tr, sc))
            else:
                # Prepare and dump
                self.hyps = [s[0] for s in self.trans]
                hyps = "\n".join(self.hyps) + "\n"
                f.write(hyps)

    def compute_metrics(self, hyp_file, scorers):
        """Computes evaluation metrics for the hypotheses.

        Returns a dictionary mapping each scorer name in `scorers` to a
        (string representation, numeric score) tuple.
        """
        results = {}
        for scorer in scorers:
            c = get_scorer(scorer)()
            score = c.compute(self.ref_files, hyp_file)
            results[scorer] = (str(score), score.score)
        return results

    def dump_json(self, filename):
        """Export decoding data into json for further visualization.

        The file holds some metadata and, for each sample, the source tokens,
        the tokens of the first hypothesis, the reference sentences (empty if
        there are no reference files) and the attention weights if any.
        """
        metadata = OrderedDict()
        metadata['models']    = [os.path.basename(m) for m in self.model_files]
        metadata['beam_size'] = self.beam_size

        # Reset iterator and collect the source sentences
        srcs = []
        self.iterator.rewind()
        for _ in range(self.n_sentences):
            sample = next(self.iterator)
            if 'x' in sample:
                srcs.append(idx_to_sent(self.models[0].src_idict, sample['x'].flatten()))

        # Save metadata
        data = {'metadata' : metadata}

        # Read reference files, closing each of them right away
        all_refs = []
        for ref_file in (self.ref_files or []):
            with open(ref_file) as f:
                all_refs.append(f.read().strip().split("\n"))

        # If the number of references differs from the number of sources,
        # multiple sources were given: cycle over the references then.
        n_ref_lines = len(all_refs[0]) if all_refs else 0
        mult_source = bool(all_refs) and n_ref_lines != len(srcs)

        # Collect reference sentences
        refs = []
        for sidx in range(self.n_sentences):
            if mult_source:
                sidx = sidx % n_ref_lines
            refs.append([ref_lines[sidx] for ref_lines in all_refs])

        # Add sources, targets, and references
        # NOTE: att_weights are not cleared from BPE in terms of number of tokens
        data['data'] = [
            {'src' : s.split(' '), 'trg' : t[0].split(' '), 'ref' : r, 'att': att}
            for s, t, att, r in zip(srcs, self.trans, self.att_weights, refs)]

        # Export the JSON
        def _default(obj):
            # Convert numpy containers and scalars to plain Python objects
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, np.generic):
                return obj.item()
            # Anything else is serialized as null
            return None

        with open(filename, 'w') as f:
            json.dump(data, f, default=_default)

# Command-line entry point: parse the options, translate, then write the
# hypotheses and print the metrics.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='nmt-translate')
    parser.add_argument('-j', '--n-jobs'        , type=int, default=8,      help="Number of processes (default: 8, 0: Auto)")
    parser.add_argument('-b', '--beam-size'     , type=int, default=12,     help="Beam size (only for beam-search)")
    parser.add_argument('-N', '--nbest'         , type=int, default=1,      help="N for N-best output (only for beam-search)")
    parser.add_argument('-r', '--seed'          , type=int, default=1234,   help="Random number seed for sampling mode (default: 1234)")

    parser.add_argument('-v', '--validmode'     , default='single',         help="Validation mode for WMT16 MMT Task2: all/pairs/single")
    parser.add_argument('-D', '--decoder'       , default='beamsearch',     choices=['beamsearch', 'argmax', 'sample', 'forced'], help="Decoding mode")

    parser.add_argument('-M', '--metrics'       , type=str, default='bleu', help="Comma separated list of metrics (bleu or bleu,meteor)")
    parser.add_argument('-o', '--saveto'        , type=str, default=None,   help="Output translations file (if not given, only metrics will be printed)")
    parser.add_argument('-e', '--export'        , action='store_true',      help="Export all decoding process to json for visualization")
    parser.add_argument('-s', '--score'         , action='store_true',      help="Print scores of each sentence even nbest == 1")
    parser.add_argument('-u', '--suppress-unks' , action='store_true',      help="Don't produce <unk>'s in beam search")

    parser.add_argument('-S', '--src-files'     , type=str, nargs='+', default=None, help="Source data(s) in order: text,image (default: validation set)")
    parser.add_argument('-R', '--ref-files'     , type=str, nargs='+', default=None, help="One or multiple reference files (default: validation set)")
    parser.add_argument('-m', '--models'        , nargs='+', required=True, help="Model files")

    args = parser.parse_args()

    if args.decoder == "forced" and (args.src_files is None or args.ref_files is None):
        print("Error: Forced decoding requires that you give src and ref files explicitly.")
        sys.exit(1)

    if args.n_jobs < 0:
        # A negative value would start no worker at all
        parser.error("--n-jobs should be a positive number or 0 (auto)")

    if args.n_jobs == 0:
        # Auto infer CPU number: keep it an integer and always use at least
        # one worker, even on machines with very few cores.
        args.n_jobs = max(1, (cpu_count() // 2) - 1)

    # Print some informations
    log.info("%d CPU processes - beam size = %2d" % (args.n_jobs, args.beam_size))
    log.info("Using %d model(s) for translation" % len(args.models))

    # Create translator object
    translator = Translator(args)
    translator.set_model_options()
    translator.start()

    out_file = args.saveto
    if not args.saveto:
        # No output file requested: write the hypotheses into a temporary
        # file so that metrics can be computed, and never dump the scores.
        args.score = False
        hypf = get_temp_file(suffix=".nbest_hyps")
        out_file = hypf.name
        hypf.close()

    # Export attentional informations if -o and -e are given
    if args.export and args.saveto:
        translator.dump_json("%s.json" % out_file)

    # Dump hypotheses
    translator.write_hyps(out_file, args.score)

    # No need to compute metrics with nbest style files
    if args.decoder != "forced" and args.nbest == 1 \
            and translator.ref_files and not args.score:
        # Compute metrics and print them. Printing is necessary for nmt-train
        print(translator.compute_metrics(out_file, args.metrics.split(",")))

    # Report success
    sys.exit(0)
